# #################################################################################################################
#  How to Enable Tensorboard on my computer:                                                                    ###
# A) On Linux:                                                                                                  ###
#       1) conda activate torch_venv                                                                            ###
#       2) tensorboard --logdir=<tb dir>                                                                        ###
#       3) wait for final print and find the port of tensorboard                                                ###
# B) In Windows:                                                                                                ###
#       1) open cmd and type: ssh -L <tensor board port>:localhost:<tensor board port> XXXX@132.72.65.199      ###
#       2) Enter password                                                                                       ###
#       3) in firefox: localhost:<tensorboard port>                                                             ###
#                                                                                                               ###
# https://serverfault.com/questions/1004529/access-an-http-server-as-localhost-from-an-external-pc-over-ssh     ###
#                                                                                                               ###
# #################################################################################################################
import numpy as np
import torch
from torch.nn import Module, MSELoss
from torch.utils.data import DataLoader
from torch.optim import Adam, Optimizer
from typing import Optional, Callable, Any, Tuple, List, Dict, Union
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR

from Utils.Constants import TBNames
from Utils import logger, config as cfg
from ModelsUtils.ModelCallbacks import EarlyStoppingCallback, ModelCheckpointCallback


class ModelTrainer:
    """
    Wraps a pytorch model with a training loop (``fit``), an evaluation pass (``evaluate``) and optional
    tensorboard logging of the losses and metrics.
    """

    def __init__(self, model: Module, tb_writer: Optional[SummaryWriter],
                 optimizer: Optional[Optimizer] = None, use_scheduler: bool = True, scheduler: Optional[Any] = None,
                 loss_func: Union[Callable[[torch.Tensor, torch.Tensor], Any], Callable[[], Callable]] = MSELoss,
                 convert_to_double: bool = False):
        """
        Class for easy training and evaluating models
        :param model: pytorch model
        :param tb_writer: tensorboard writer object, if None will try to extract it from ``model.tb_writer()``
        :param optimizer: Optimizer object for training, if None will use Adam with lr=0.001
        :param use_scheduler: False for no LR scheduler, True for using the default or parameter scheduler
        :param scheduler: learning rate scheduler, if None and use_scheduler is True will use
                          StepLR(step_size=20, gamma=0.15)
        :param loss_func: loss function, can be a loss class / zero-argument factory (it will be called once to
                          build the loss) or a ready-to-use callable taking (preds, labels)
        :param convert_to_double: should convert inputs and labels to double data type
        """
        self._model: Module = model
        self._convert_to_double = convert_to_double
        if tb_writer is None:
            try:
                tb_writer = model.tb_writer()
            except AttributeError:
                pass        # The model doesn't expose a writer - tensorboard logging stays disabled
        self._tb_writer = tb_writer
        self._optim = Adam(self._model.parameters(), lr=0.001) if optimizer is None else optimizer

        if not use_scheduler:
            self._scheduler = None
        elif scheduler is None:
            self._scheduler = StepLR(self._optim, step_size=20, gamma=0.15)
        else:
            self._scheduler = scheduler

        # Suffix for log messages and tensorboard tags: functions and classes have __name__, Module instances
        # only have _get_name(). If neither works the suffix stays empty.
        self._loss_name = ''
        for name_getter in (lambda func: func.__name__, lambda func: func._get_name()):
            try:
                self._loss_name = '_' + name_getter(loss_func)
                break
            except (AttributeError, TypeError):
                continue

        if isinstance(loss_func, Module):
            # Already an instantiated loss, no need to probe it by calling it without arguments
            self._loss_func = loss_func
        else:
            try:
                self._loss_func = loss_func()       # class or zero-argument factory
            except TypeError:
                self._loss_func = loss_func         # plain callable that requires (preds, labels)

        self._times_called_eval = 0

    def __single_step(self, inputs, labels, metrics_funcs):
        """
        Calculate single step: forward pass, loss and metrics (no gradient step is applied here).
        :param inputs: batch of model inputs
        :param labels: batch of labels, compared against the squeezed model output
        :param metrics_funcs: list of callables taking (preds, labels), each result is cast to float
        :return: Tuple of step loss and list of metrics values, Tuple of None if preds has nan values
        """
        preds = self._model(inputs).squeeze()
        if torch.any(torch.isnan(preds)):
            return None, None
        curr_loss = self._loss_func(preds, labels)
        curr_metric = [float(metric(preds, labels)) for metric in metrics_funcs]
        return curr_loss, curr_metric

    def _single_epoch_run(self, dataloader: DataLoader, is_train: bool,
                          metrics_funcs: List[Callable[[torch.Tensor, torch.Tensor], Any]]) -> Tuple[float, np.ndarray]:
        """
        Runs a single epoch from a DataLoader object
        :param dataloader: iterable of (inputs, labels) batches
        :param is_train: True if is training and should apply gradient steps
        :param metrics_funcs: list of callables taking (preds, labels)
        :return: mean loss and array of mean metrics over the non-skipped steps. Both are nan (the metrics as an
                 array with one nan per metric) if the DataLoader didn't yield any batch
        """
        loss = 0
        step = 0
        metrics_values = np.zeros(len(metrics_funcs))
        skipped_steps = 0
        if is_train:
            self._model.train()
        else:
            self._model.eval()

        for step, (inputs, labels) in enumerate(dataloader, start=1):
            inputs = inputs.to(cfg.device)
            labels = labels.to(cfg.device)
            if self._convert_to_double:
                inputs = inputs.double()
                labels = labels.double()
            if torch.any(torch.isnan(labels)):
                skipped_steps += 1
                continue

            # Gradients are tracked only while training
            with torch.set_grad_enabled(is_train):
                curr_loss, curr_metrics_results = self.__single_step(inputs, labels, metrics_funcs)
            if curr_loss is None:       # Model produced nan predictions for this batch
                skipped_steps += 1
                continue

            if is_train:
                self._optim.zero_grad()
                curr_loss.backward()
                self._optim.step()
            loss += float(curr_loss)
            metrics_values += curr_metrics_results

        if is_train and self._scheduler is not None:
            self._scheduler.step()

        if step == 0:
            logger().warning('ModelTrainer::_single_epoch_run', "Didn't do any steps in current epoch, is_train: ", is_train)
            # Keep the metrics as an array so callers can still zip / index them
            return np.nan, np.full(len(metrics_funcs), np.nan)
        if skipped_steps != 0:
            logger().force_log_and_print('ModelTrainer::_single_epoch_run', f'Number of skipped steps: {skipped_steps} from: {step}')

        total_steps = step - skipped_steps
        if total_steps == 0:
            logger().warning('ModelTrainer::_single_epoch_run', f'Zero steps: step={step}, skipped: {skipped_steps},'
                                                                f'loss: {loss} - setting to 1')
            total_steps = 1     # Avoid division by zero when every step was skipped

        return loss / total_steps, metrics_values / total_steps

    def _calc_metric(self, dataloader: DataLoader, metric: Callable[[torch.Tensor, torch.Tensor], Any]):
        """
        Calculate a metric over all the predictions and labels of a dataloader at once
        :param dataloader: iterable of (inputs, labels) batches. Inputs are passed to the model as they are
                           yielded (no device transfer is done here)
        :param metric: callable taking (all_preds, all_labels) as cpu tensors concatenated on the first dim
        :return: -1 if metric is None, nan if the dataloader is empty, else the metric result
        """
        if metric is None:
            return -1
        self._model.eval()
        all_preds = list()
        all_labels = list()
        with torch.no_grad():
            for inp, labels in dataloader:
                # Keep everything as detached cpu tensors so they can be concatenated with torch
                all_preds.append(self._model(inp).detach().cpu())
                all_labels.append(labels.detach().cpu())

        if len(all_preds) == 0:
            logger().warning('ModelTrainer::_calc_metric', "Dataloader is empty, can't calculate metric")
            return np.nan

        return metric(torch.cat(all_preds), torch.cat(all_labels))

    def _log_tb(self, epoch, losses: List[float], losses_names: List[str],
                metrics: Dict[str, List[Union[List[str], List[float]]]]):
        """
        Logs multiple info to tensorboard if writer exists
        :param epoch: step value passed to the tensorboard writer
        :param losses: List of loss values
        :param losses_names: List of names for loss values
        :param metrics: Dictionary for metrics to log. The keys are metrics groups names. Each value is a list made
                        from 2 lists, the first sub list is the names of each metric, second sub list is the metrics
                        values, e.g., {'G': [['g_train','g_val'],[0.3, 0.1]]} - One metric group called g with two
                        values 0.3 for train and 0.1 for validation
        :return:
        """
        if self._tb_writer is None:
            return

        for curr_name, curr_loss in zip(losses_names, losses):
            self._tb_writer.add_scalar(curr_name, curr_loss, epoch)  # log each loss individually
        self._tb_writer.add_scalars('LOSS', dict(zip(losses_names, losses)), epoch)

        if not metrics:
            return
        for group_name, curr_group in metrics.items():
            group_names, group_values = curr_group[0], curr_group[1]
            for curr_name, curr_metric in zip(group_names, group_values):
                self._tb_writer.add_scalar(curr_name, curr_metric, epoch)   # log each metric individually
            self._tb_writer.add_scalars(group_name, dict(zip(group_names, group_values)), epoch)

    @staticmethod
    def __set_up_metrics(metrics: Optional[List[Callable[[torch.Tensor, torch.Tensor], Any]]],
                         metrics_names: Optional[List[str]]) -> Tuple[List[Callable[[torch.Tensor, torch.Tensor], Any]], List[str]]:
        """
        Setups the lists of metrics funcs and metrics names to avoid any problems
        :param metrics: metrics callables, None or empty for no metrics
        :param metrics_names: names for the metrics, can be None or shorter than metrics. Each missing name is
                              taken from the function __name__, or 'metric_<idx>' if it doesn't have one
        :return: list of metrics funcs, list of metrics names. Without metrics, a single dummy func that always
                 returns -1 and an empty names list are returned
        """
        if metrics is None or len(metrics) == 0:
            return [lambda *args: -1], list()

        metrics_names = list() if metrics_names is None else list(metrics_names)
        unnamed_funcs = metrics[len(metrics_names):]
        # Some torch funcs don't have names :( - fall back per function so the others keep their real name
        metrics_names += [getattr(func, '__name__', f'metric_{idx}') for idx, func in enumerate(unnamed_funcs)]

        return metrics, metrics_names

    def fit(self, epochs: int, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader] = None,
            metrics_funcs: Optional[List[Callable[[torch.Tensor, torch.Tensor], Any]]] = None,
            metrics_names: Optional[List[str]] = None,
            checkpoint_cb: Optional[ModelCheckpointCallback] = None,
            early_stopping: Optional[EarlyStoppingCallback] = None):
        """
        Fit the model
        :param epochs: maximal number of epochs to run
        :param train_dataloader: batches used for the gradient steps
        :param val_dataloader: optional batches for validation after each epoch, without it the validation loss
                               and metrics are reported as -1
        :param metrics_funcs: optional list of callables taking (preds, labels)
        :param metrics_names: optional names for the metrics
        :param checkpoint_cb: optional callback, receives the model through its ``model`` attribute and is called
                              every epoch with (train loss, val loss, first train metric, first val metric)
        :param early_stopping: optional callback called every epoch with the same arguments as checkpoint_cb,
                               training stops when it returns a truthy value
        :return: the trained model
        """
        has_val = val_dataloader is not None
        metrics_funcs, metrics_names = self.__set_up_metrics(metrics_funcs, metrics_names)
        curr_val_loss = -1
        # Sized by the funcs (never empty) and not by the names, so [0] is valid even without metrics
        curr_val_metrics = [-1] * len(metrics_funcs)

        if checkpoint_cb is not None:
            checkpoint_cb.model = self._model

        for epoch in range(epochs):
            logger().force_log_and_print('ModelTrainer::fit', f'Epoch: {epoch}')
            curr_train_loss, curr_train_metrics = self._single_epoch_run(train_dataloader, True, metrics_funcs)
            logger().force_log_and_print('ModelTrainer::fit', f'Training loss{self._loss_name}: {curr_train_loss: .3f}, Metrics: ',
                                         {name: round(value, 3) for name, value in zip(metrics_names, curr_train_metrics)})

            if has_val:
                curr_val_loss, curr_val_metrics = self._single_epoch_run(val_dataloader, False, metrics_funcs)
                logger().force_log_and_print('ModelTrainer::fit', f'Validation loss{self._loss_name}: {curr_val_loss: .3f}, Metrics: ',
                                             {name: round(value, 3) for name, value in zip(metrics_names, curr_val_metrics)})

            tb_metrics_dict = {name: [[f'{name}/Train', f'{name}/Validation'], [train_metric, val_metric]]
                               for name, train_metric, val_metric
                               in zip(metrics_names, curr_train_metrics, curr_val_metrics)}

            print('')       # This is just for pretty printing on console or sbatch out file
            self._log_tb(epoch, [curr_train_loss, curr_val_loss],
                         [f'{TBNames.LOSS_TRAIN}{self._loss_name}', f'{TBNames.LOSS_VAL}{self._loss_name}'],
                         tb_metrics_dict)
            if checkpoint_cb is not None:
                checkpoint_cb(curr_train_loss, curr_val_loss, curr_train_metrics[0], curr_val_metrics[0])
            if early_stopping is not None and\
                    early_stopping(curr_train_loss, curr_val_loss, curr_train_metrics[0], curr_val_metrics[0]):
                logger().force_log_and_print('ModelTrainer::fit', 'Early Stopping')
                break

        return self._model

    def evaluate(self, test_dataloader: DataLoader,
                 metrics_funcs: Optional[List[Callable[[torch.Tensor, torch.Tensor], Any]]] = None,
                 metrics_names: Optional[List[str]] = None) -> Tuple[float, Any]:
        """
        Evaluate the model on a dataloader without applying gradient steps
        :param test_dataloader: batches to evaluate on
        :param metrics_funcs: optional list of callables taking (preds, labels)
        :param metrics_names: optional names for the metrics, logged with an 'EVAL/' prefix
        :return: mean loss, array of mean metrics values
        """
        self._model.eval()
        self._times_called_eval += 1    # Used as the tensorboard step, so repeated evaluations don't overlap

        metrics_funcs, metrics_names = self.__set_up_metrics(metrics_funcs, metrics_names)
        metrics_names = [f'EVAL/{curr_name}' for curr_name in metrics_names]
        loss, metrics_results = self._single_epoch_run(test_dataloader, False, metrics_funcs)
        logger().force_log_and_print('ModelTrainer::evaluate', f'Loss: {loss: .3f}, Metrics: ',
                                     {name: round(val, 3) for name, val in zip(metrics_names, metrics_results)})

        metrics_log_dict = {'all_metrics': [metrics_names, metrics_results]}
        self._log_tb(self._times_called_eval, [loss], ['LOSS/Eval'], metrics_log_dict)
        return loss, metrics_results
